A skewed dataset is one with a class imbalance: some values occur far more often than others. Skew leads to slow or failing Spark jobs, which often end with an OOM (out of memory) error.

When a join on a skewed dataset performs badly, it is usually because the key(s) being joined on are imbalanced. As a result, the majority of the data falls onto a single partition, which takes far longer to complete than the other partitions.

Some hints that a dataset is skewed:

  1. The key(s) consist mainly of null values, which all fall onto a single partition.
  2. A small subset of the key values makes up a high percentage of the total rows, and those rows all fall onto a single partition (see the sketch after this list for a quick check).
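
A quick way to check for either case is to look at the frequency of each join key. This is only a sketch (the function name is ours; you pass in whichever DataFrame and key you are about to join on):

from pyspark.sql import DataFrame
from pyspark.sql import functions as F

def show_key_distribution(df: DataFrame, key: str, n: int = 10):
    # Show the n most frequent values of `key`, nulls included.
    # One value (or null) accounting for most of the rows is a strong
    # hint that a join on `key` will be skewed.
    df.groupBy(key).count().orderBy(F.desc("count")).show(n)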

We will go through both of these cases and see how we can combat them.

Library Imports

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

Template

spark = (
    SparkSession.builder
    .master("local")
    .appName("Exploring Joins")
    .config("spark.some.config.option", "some-value")
    .getOrCreate()
)

sc = spark.sparkContext

Situation 2: High Frequency Keys

Initial Datasets

customers = spark.createDataFrame([
    (1, "John"), 
    (2, "Bob"),
], ["customer_id", "first_name"])

customers.toPandas()
   customer_id  first_name
0            1        John
1            2         Bob

orders = spark.createDataFrame([
    (i, 1 if i < 95 else 2, "order #{}".format(i)) for i in range(100) 
], ["id", "customer_id", "order_name"])

orders.toPandas().tail(6)
    id  customer_id  order_name
94  94            1   order #94
95  95            2   order #95
96  96            2   order #96
97  97            2   order #97
98  98            2   order #98
99  99            2   order #99
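
We can confirm the imbalance before joining by counting rows per customer_id (the same check sketched earlier):

# With the data above, this shows 95 orders for customer 1 and 5 for customer 2.
orders.groupBy("customer_id").count().orderBy(F.desc("count")).show()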

Option 1: Inner Join

df = customers.join(orders, "customer_id")

df.toPandas().tail(10)
    customer_id  first_name  id  order_name
90            1        John  90   order #90
91            1        John  91   order #91
92            1        John  92   order #92
93            1        John  93   order #93
94            1        John  94   order #94
95            2         Bob  95   order #95
96            2         Bob  96   order #96
97            2         Bob  97   order #97
98            2         Bob  98   order #98
99            2         Bob  99   order #99
df.explain()
== Physical Plan ==
*(5) Project [customer_id#122L, first_name#123, id#126L, order_name#128]
+- *(5) SortMergeJoin [customer_id#122L], [customer_id#127L], Inner
   :- *(2) Sort [customer_id#122L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(customer_id#122L, 200)
   :     +- *(1) Filter isnotnull(customer_id#122L)
   :        +- Scan ExistingRDD[customer_id#122L,first_name#123]
   +- *(4) Sort [customer_id#127L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(customer_id#127L, 200)
         +- *(3) Filter isnotnull(customer_id#127L)
            +- Scan ExistingRDD[id#126L,customer_id#127L,order_name#128]

What Happened:

  • We want to find what orders each customer made, so we will be joining the customers table to the orders table.
  • To perform the join, Spark hash-partitions both tables on customer_id.
  • Because of how we created the data, 95% of the rows land on a single partition.

Results:

  • Similar to the Null Skew case, this means that the single large task/partition will take much longer to complete than the others, and will most likely error out; the quick check below makes the imbalance visible.
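
One rough way to see this on the toy data (not something you would keep in a real job) is to count how many rows of the joined result end up in each shuffle partition:

# Rows per shuffle partition of the joined result. With this data,
# roughly 95% of the rows land in a single partition.
(
    df
    .groupBy(F.spark_partition_id().alias("partition_id"))
    .count()
    .orderBy(F.desc("count"))
    .show(5)
)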

Option 2: Salt the key, then Join

Helper Function

def data_skew_helper(left, right, key, number_of_partitions, how="inner"):
    # Give each row on the left a random salt value in [0, number_of_partitions).
    salt_value = F.lit(F.rand() * number_of_partitions % number_of_partitions).cast('int')
    left = left.withColumn("salt", salt_value)

    # Explode the right side so every row appears once per possible salt value.
    salt_col = F.explode(F.array([F.lit(i) for i in range(number_of_partitions)])).alias("salt")
    right = right.select("*", salt_col)

    # Join on the original key plus the salt, then drop the helper column.
    return left.join(right, [key, "salt"], how).drop("salt")
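
For reference, the helper would be called as below; the rest of this section performs the same steps by hand so we can inspect each intermediate result.

df = data_skew_helper(customers, orders, "customer_id", 5)
df.count()  # 100 rows, the same result as the plain inner join above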

Example

num_of_partitions = 5
left = customers

salt_value = F.lit(F.rand() * num_of_partitions % num_of_partitions).cast('int')    
left = left.withColumn("salt", salt_value)

left.toPandas().head(5)
   customer_id  first_name  salt
0            1        John     4
1            2         Bob     3

right = orders

salt_col = F.explode(F.array([F.lit(i) for i in range(num_of_partitions)])).alias("salt")
right = right.select("*",  salt_col)

right.toPandas().head(10)
   id  customer_id  order_name  salt
0   0            1    order #0     0
1   0            1    order #0     1
2   0            1    order #0     2
3   0            1    order #0     3
4   0            1    order #0     4
5   1            1    order #1     0
6   1            1    order #1     1
7   1            1    order #1     2
8   1            1    order #1     3
9   1            1    order #1     4

df = left.join(right, ["customer_id", "salt"])

df.orderBy('id').toPandas().tail(10)
    customer_id  salt  first_name  id  order_name
90            1     4        John  90   order #90
91            1     4        John  91   order #91
92            1     4        John  92   order #92
93            1     4        John  93   order #93
94            1     4        John  94   order #94
95            2     3         Bob  95   order #95
96            2     3         Bob  96   order #96
97            2     3         Bob  97   order #97
98            2     3         Bob  98   order #98
99            2     3         Bob  99   order #99

df.explain()
== Physical Plan ==
*(5) Project [customer_id#122L, salt#136, first_name#123, id#126L, order_name#128]
+- *(5) SortMergeJoin [customer_id#122L, salt#136], [customer_id#127L, salt#141], Inner
   :- *(2) Sort [customer_id#122L ASC NULLS FIRST, salt#136 ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(customer_id#122L, salt#136, 200)
   :     +- *(1) Filter (isnotnull(salt#136) && isnotnull(customer_id#122L))
   :        +- *(1) Project [customer_id#122L, first_name#123, cast(((rand(-8040129551223767613) * 5.0) % 5.0) as int) AS salt#136]
   :           +- Scan ExistingRDD[customer_id#122L,first_name#123]
   +- *(4) Sort [customer_id#127L ASC NULLS FIRST, salt#141 ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(customer_id#127L, salt#141, 200)
         +- Generate explode([0,1,2,3,4]), [id#126L, customer_id#127L, order_name#128], false, [salt#141]
            +- *(3) Filter isnotnull(customer_id#127L)
               +- Scan ExistingRDD[id#126L,customer_id#127L,order_name#128]

What Happened:

  • We created a new salt column for both datasets.
  • On one of the datasets (here the orders), we duplicate the data so that there is a row for each salt value.
  • When performing the join, we perform a hashpartitioning on [customer_id, salt].

Results:

  • By producing a row per salt value, we essentially created (num_partitions - 1) * N additional rows.
  • This creates more data, but lets us spread it across more partitions, as seen from hashpartitioning(customer_id, salt); the quick check below quantifies the extra rows.
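
With the toy data the extra rows are easy to quantify (these counts follow directly from the data creation above):

orders.count()  # 100 rows originally
right.count()   # 500 rows after exploding the salt column: (5 - 1) * 100 = 400 extra
df.count()      # the salted join still returns 100 rows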

Summary

All to say:

  • By salting our keys, the skewed dataset gets divided into smaller partitions, which removes the skew.
  • Again, we sacrifice extra resources in order to get a performance gain or a successful run.
  • We produced more data, creating (num_partitions - 1) * N additional rows on the right side.
